# Standard library
import socket
import threading
import time
from collections import deque

# Third-party
import cv2
import numpy as np
import paddle
import requests
#import matplotlib.pyplot as plt

# Project-local
from ppyolo import get_object_position_ppyolo
from deep_sort.detection import *
from deep_sort import preprocessing
from deep_sort.tracker import Tracker
from deep_sort import nn_matching
from extractor_new import Extractor
from ResNet_ReID_paddle.ResNet50_ReID import Net


"""# 2初始化套接字
tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

# 3建立链接  要传入链接的服务器ip和port
tcp_socket.connect(('134.122.189.201', 5555))"""

"""lock = threading.Lock()"""

# Large near-power-of-two constants; multiplying them by a label-derived seed
# and reducing mod 255 spreads track IDs over visually distinct BGR colors.
palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)

def color_for_labels(label):
    """Return a deterministic BGR color tuple for an integer track label."""
    seed = label * label - label + 1
    channels = []
    for base in palette:
        channels.append(int((base * seed) % 255))
    return tuple(channels)

# --- ReID network setup ---
# Pretrained ResNet50 person re-identification weights.
model_path = 'ResNet_ReID_paddle/final.pdparams'
# reid=True switches the project Net into feature-embedding mode.
# NOTE(review): num_classes=751 suggests Market-1501 training — confirm in Net.
net = Net(num_classes=751, reid=True)

paddle.set_device("gpu")
static_dict = paddle.load(model_path)
net.set_state_dict(static_dict)
net.eval()  # inference mode only; no training-time layers active

# Wraps the network to embed detection crops into appearance feature vectors.
extr = Extractor(net)
cap = cv2.VideoCapture('test.mp4')

max_cosine_distance = 0.45  # cosine-distance gating threshold for appearance matching
nn_budget = 100  # max stored appearance samples per track in the metric
nms_max_overlap = 0.4  # max allowed box overlap before NMS suppresses

# DeepSORT appearance metric + tracker built on top of it.
metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
tracker = Tracker(metric)
k = 1  # 1-based frame counter, used only for progress printing

"""def craw(id, x, y, w, h):
    with lock:
        payload = {'ID':str(id), "x":str(int(x)), "y":str(int(y)), "w":str(int(w)), "h":str(int(h))}
        res = requests.get("http://43.132.244.59:5555/yolo", params=payload) #http://134.122.189.201:5555/yolo
        #time.sleep(1)"""

# Main per-frame loop: detect -> NMS -> track -> draw, until the video ends
# or the user presses 'q'.
while True:
    ret, img = cap.read()
    if not ret:
        # Normal end of stream (or a read failure): release resources and stop.
        cap.release()
        cv2.destroyAllWindows()
        print('video ended or frame could not be read')
        break

    cv2.namedWindow("img", cv2.WINDOW_NORMAL)
    # Detect people above a 0.61 confidence threshold; `extr` embeds each
    # crop into a ReID feature vector used by the tracker's appearance metric.
    results, confidences, features = get_object_position_ppyolo(img, 0.61, extr)
    detections = [Detection(bbox, confidence, feature)
                  for bbox, confidence, feature in zip(results, confidences, features)]

    # Non-maximum suppression to drop heavily overlapping detections.
    boxes = np.array([d.tlwh for d in detections])
    scores = np.array([d.confidence for d in detections])
    indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
    detections = [detections[i] for i in indices]

    # Advance the Kalman filters, then associate detections with tracks.
    tracker.predict()
    tracker.update(detections)

    # Thin white boxes: raw post-NMS detections, for visual reference.
    for det in detections:
        bbox = det.to_tlbr()
        cv2.rectangle(img, (int(bbox[0]), int(bbox[1])),
                      (int(bbox[2]), int(bbox[3])), (255, 255, 255), 1)

    # Colored boxes + ID labels: confirmed tracks updated this frame.
    for track in tracker.tracks:
        if not track.is_confirmed() or track.time_since_update > 1:
            continue
        bbox = track.to_tlbr()
        x, y = bbox[0], bbox[1]
        w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
        track_id = int(track.track_id)  # renamed from `id` to stop shadowing the builtin
        color = color_for_labels(track_id)
        t_size = cv2.getTextSize(str(track_id), cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
        cv2.rectangle(img, (int(x), int(y)), (int(x + w), int(y + h)), color, 1)
        # Filled background patch so the ID text stays readable.
        cv2.rectangle(img, (int(x), int(y)),
                      (int(x + t_size[0] + 3), int(y + t_size[1] + 4)), color, -1)
        cv2.putText(img, str(track_id), (int(x), int(y + t_size[1] + 4)),
                    cv2.FONT_HERSHEY_PLAIN, 2, (255, 255, 255), 4)

    # NOTE(review): this counts raw detections before NMS, not drawn tracks.
    count = len(results)
    cv2.putText(img, "Person be detected: " + str(count), (20, 20),
                0, 5e-3 * 200, (0, 255, 0), 1)
    cv2.imshow('img', img)
    # The original discarded the waitKey code, leaving no way to quit;
    # mask to a byte and exit cleanly on 'q'.
    if cv2.waitKey(1) & 0xFF == ord('q'):
        cap.release()
        cv2.destroyAllWindows()
        break
    print('frame%d' % k)
    k += 1

#tcp_socket.close()